{
 "cells": [
  {
   "cell_type": "raw",
   "id": "abe47592-909c-4844-bf44-9e55c2fb4bfa",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_position: 1\n",
    "title: RAG\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91c5ef3d",
   "metadata": {},
   "source": [
    "Let's add a retrieval step in front of a prompt and an LLM. Together these make up a \"retrieval-augmented generation\" (RAG) chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f25d9e9-d192-42e9-af50-5660a4bfb0d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic (not !pip) so the packages are installed into the running kernel's environment.\n",
    "%pip install langchain openai faiss-cpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "33be32af",
   "metadata": {},
   "outputs": [],
   "source": [
    "from operator import itemgetter\n",
    "\n",
    "from langchain.prompts import ChatPromptTemplate\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.embeddings import OpenAIEmbeddings\n",
    "from langchain.schema.output_parser import StrOutputParser\n",
    "from langchain.schema.runnable import RunnablePassthrough\n",
    "from langchain.vectorstores import FAISS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bfc47ec1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index a single sentence with FAISS (embedded via OpenAI) and expose the index as a retriever.\n",
    "vectorstore = FAISS.from_texts(\n",
    "    [\"harrison worked at kensho\"],\n",
    "    embedding=OpenAIEmbeddings(),\n",
    ")\n",
    "retriever = vectorstore.as_retriever()\n",
    "\n",
    "# The prompt instructs the model to answer using only the retrieved context.\n",
    "template = (\n",
    "    \"Answer the question based only on the following context:\\n\"\n",
    "    \"{context}\\n\"\n",
    "    \"\\n\"\n",
    "    \"Question: {question}\\n\"\n",
    ")\n",
    "prompt = ChatPromptTemplate.from_template(template)\n",
    "\n",
    "# Chat model that produces the final answer.\n",
    "model = ChatOpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "eae31755",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The leading dict receives the question string: the retriever uses it to fetch the context,\n",
    "# and RunnablePassthrough forwards it unchanged into the prompt's {question} slot.\n",
    "chain = (\n",
    "    {\n",
    "        \"context\": retriever,\n",
    "        \"question\": RunnablePassthrough(),\n",
    "    }\n",
    "    | prompt\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f3040b0c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Harrison worked at Kensho.'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This chain takes the bare question string as its input.\n",
    "chain.invoke(\"where did harrison work?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e1d20c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same prompt as before, extended with a {language} slot for the language of the answer.\n",
    "template = (\n",
    "    \"Answer the question based only on the following context:\\n\"\n",
    "    \"{context}\\n\"\n",
    "    \"\\n\"\n",
    "    \"Question: {question}\\n\"\n",
    "    \"\\n\"\n",
    "    \"Answer in the following language: {language}\\n\"\n",
    ")\n",
    "prompt = ChatPromptTemplate.from_template(template)\n",
    "\n",
    "# The input is now a dict, so itemgetter picks out the field each prompt variable needs;\n",
    "# only the question is sent to the retriever.\n",
    "chain = (\n",
    "    {\n",
    "        \"context\": itemgetter(\"question\") | retriever,\n",
    "        \"question\": itemgetter(\"question\"),\n",
    "        \"language\": itemgetter(\"language\"),\n",
    "    }\n",
    "    | prompt\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7ee8b2d4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Harrison ha lavorato a Kensho.'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Both keys are required: the chain reads `question` and `language` from the input dict.\n",
    "chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f007669c",
   "metadata": {},
   "source": [
    "## Conversational Retrieval Chain\n",
    "\n",
    "We can easily add conversation history. This primarily means passing the chat history into the chain as an additional input, `chat_history`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3f30c348",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.schema.runnable import RunnableMap\n",
    "from langchain.schema import format_document"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "64ab1dbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "# Prompt that rewrites a follow-up question so it can be understood without the chat history.\n",
    "_template = (\n",
    "    \"Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.\\n\"\n",
    "    \"\\n\"\n",
    "    \"Chat History:\\n\"\n",
    "    \"{chat_history}\\n\"\n",
    "    \"Follow Up Input: {question}\\n\"\n",
    "    \"Standalone question:\"\n",
    ")\n",
    "CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(template=_template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7d628c97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt for the final answer; the model may only use the retrieved context.\n",
    "template = (\n",
    "    \"Answer the question based only on the following context:\\n\"\n",
    "    \"{context}\\n\"\n",
    "    \"\\n\"\n",
    "    \"Question: {question}\\n\"\n",
    ")\n",
    "ANSWER_PROMPT = ChatPromptTemplate.from_template(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f60a5d0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# By default a document is rendered as just its page content.\n",
    "DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template=\"{page_content}\")\n",
    "\n",
    "\n",
    "def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator=\"\\n\\n\"):\n",
    "    \"\"\"Format every document with `document_prompt` and join the results with `document_separator`.\"\"\"\n",
    "    return document_separator.join(\n",
    "        format_document(doc, document_prompt) for doc in docs\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7d007db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple, List\n",
    "def _format_chat_history(chat_history: List[Tuple]) -> str:\n",
    "    buffer = \"\"\n",
    "    for dialogue_turn in chat_history:\n",
    "        human = \"Human: \" + dialogue_turn[0]\n",
    "        ai = \"Assistant: \" + dialogue_turn[1]\n",
    "        buffer += \"\\n\" + \"\\n\".join([human, ai])\n",
    "    return buffer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5c32cc89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: rewrite the follow-up question as a standalone question, using the formatted chat history.\n",
    "_inputs = RunnableMap(\n",
    "    {\n",
    "        \"standalone_question\": (\n",
    "            {\n",
    "                \"question\": itemgetter(\"question\"),\n",
    "                \"chat_history\": lambda x: _format_chat_history(x[\"chat_history\"]),\n",
    "            }\n",
    "            | CONDENSE_QUESTION_PROMPT\n",
    "            | ChatOpenAI(temperature=0)\n",
    "            | StrOutputParser()\n",
    "        ),\n",
    "    }\n",
    ")\n",
    "\n",
    "# Step 2: retrieve documents for the standalone question and merge them into one context string.\n",
    "_context = {\n",
    "    \"context\": itemgetter(\"standalone_question\") | retriever | _combine_documents,\n",
    "    \"question\": itemgetter(\"standalone_question\"),\n",
    "}\n",
    "\n",
    "# There is no output parser at the end, so the chain returns the chat model's message object.\n",
    "conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "135c8205",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='Harrison was employed at Kensho.', additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First turn: there is no prior conversation, so chat_history is an empty list.\n",
    "conversational_qa_chain.invoke({\n",
    "    \"question\": \"where did harrison work?\",\n",
    "    \"chat_history\": [],\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "424e7e7a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='Harrison worked at Kensho.', additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Follow-up turn: the pronoun in the question can only be resolved from the (human, ai) pair in chat_history.\n",
    "conversational_qa_chain.invoke({\n",
    "    \"question\": \"where did he work?\",\n",
    "    \"chat_history\": [(\"Who wrote this notebook?\", \"Harrison\")],\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5543183",
   "metadata": {},
   "source": [
    "### With Memory and returning source documents\n",
    "\n",
    "This shows how to use memory with the chain above. The memory is managed outside the chain: we load it as the first step and save each new turn ourselves after the call. To return the retrieved documents, we just need to pass them through all the way to the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e31dd17c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.memory import ConversationBufferMemory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d4bffe94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# return_messages=True makes the memory hand back the history as message objects.\n",
    "# input_key / output_key match the keys used when a turn is saved to the memory.\n",
    "memory = ConversationBufferMemory(\n",
    "    return_messages=True,\n",
    "    input_key=\"question\",\n",
    "    output_key=\"answer\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "733be985",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: load the memory.\n",
    "# The first step must be an actual RunnableMap: two plain dicts joined with `|` would not build a chain.\n",
    "loaded_memory = RunnableMap(\n",
    "    {\n",
    "        \"question\": itemgetter(\"question\"),\n",
    "        \"memory\": memory.load_memory_variables,\n",
    "    }\n",
    ")\n",
    "\n",
    "# Step 2: expand the loaded memory variables into the chat history.\n",
    "expanded_memory = {\n",
    "    \"question\": itemgetter(\"question\"),\n",
    "    \"chat_history\": lambda x: x[\"memory\"][\"history\"],\n",
    "}\n",
    "\n",
    "# Step 3: condense the question and the chat history into a standalone question.\n",
    "standalone_question = {\n",
    "    \"standalone_question\": (\n",
    "        {\n",
    "            \"question\": itemgetter(\"question\"),\n",
    "            \"chat_history\": lambda x: _format_chat_history(x[\"chat_history\"]),\n",
    "        }\n",
    "        | CONDENSE_QUESTION_PROMPT\n",
    "        | ChatOpenAI(temperature=0)\n",
    "        | StrOutputParser()\n",
    "    ),\n",
    "}\n",
    "\n",
    "# Step 4: retrieve the documents for the standalone question.\n",
    "retrieved_documents = {\n",
    "    \"docs\": itemgetter(\"standalone_question\") | retriever,\n",
    "    \"question\": itemgetter(\"standalone_question\"),\n",
    "}\n",
    "\n",
    "# Step 5: construct the inputs for the final prompt.\n",
    "final_inputs = {\n",
    "    \"context\": lambda x: _combine_documents(x[\"docs\"]),\n",
    "    \"question\": itemgetter(\"question\"),\n",
    "}\n",
    "\n",
    "# Step 6: produce the answer and pass the retrieved documents through next to it.\n",
    "answer = {\n",
    "    \"answer\": final_inputs | ANSWER_PROMPT | ChatOpenAI(),\n",
    "    \"docs\": itemgetter(\"docs\"),\n",
    "}\n",
    "\n",
    "# Finally, put all the steps together.\n",
    "final_chain = loaded_memory | expanded_memory | standalone_question | retrieved_documents | answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "806e390c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'answer': AIMessage(content='Harrison was employed at Kensho.', additional_kwargs={}, example=False),\n",
       " 'docs': [Document(page_content='harrison worked at kensho', metadata={})]}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep the inputs and the result around: both are needed to save this turn to memory afterwards.\n",
    "inputs = {\"question\": \"where did harrison work?\"}\n",
    "result = final_chain.invoke(inputs)\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "977399fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note that the chain does not save to memory automatically.\n",
    "# This will be improved in the future; for now you need to save each turn yourself.\n",
    "# Only the text content of the answer message is stored, under the `answer` key.\n",
    "memory.save_context(inputs, {\"answer\": result[\"answer\"].content})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f94f7de4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'history': [HumanMessage(content='where did harrison work?', additional_kwargs={}, example=False),\n",
       "  AIMessage(content='Harrison was employed at Kensho.', additional_kwargs={}, example=False)]}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The saved turn now shows up in the history as a human / AI message pair.\n",
    "memory.load_memory_variables({})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "poetry-venv",
   "language": "python",
   "name": "poetry-venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
